import torch.nn as nn

class CAB(nn.Module):
    """Channel attention block: global average pooling followed by a bottleneck MLP.

    This class only builds the layers (``pool`` and ``mlp``); no ``forward``
    method is defined here. NOTE(review): a ``forward`` is required before the
    module can be called — confirm whether one is meant to be added.

    Args:
        channels: Number of input channels; also the output width of ``mlp``.
        reduction_ratio: Factor by which the MLP's hidden layer is narrower
            than ``channels``. Defaults to 16. The hidden width is clamped to
            at least 1.
    """

    def __init__(self, channels, reduction_ratio=16):
        super().__init__()

        # Adaptive pooling to a 1x1 output reduces each channel to one value
        # regardless of the input's spatial size. The original nn.AvgPool2d()
        # could not be constructed: it requires an explicit kernel_size.
        self.pool = nn.AdaptiveAvgPool2d(1)

        # Clamp so that channels < reduction_ratio does not yield a zero-width
        # hidden layer.
        hidden = max(channels // reduction_ratio, 1)
        self.mlp = nn.Sequential(
            # Assumes pooled input of shape (N, C, 1, 1) — i.e. 4-D (N, C, H, W)
            # input to ``pool``; Flatten (start_dim=1) turns it into (N, C).
            nn.Flatten(),
            nn.Linear(channels, hidden),
            nn.ReLU(),
            nn.Linear(hidden, channels),
        )

